/*
 * Copyright (C) 2015 Edward Raff
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package com.jstarcraft.ai.jsat.parameters;

import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.ClassificationModelEvaluation;
import com.jstarcraft.ai.jsat.classifiers.Classifier;
import com.jstarcraft.ai.jsat.distributions.Distribution;
import com.jstarcraft.ai.jsat.exceptions.FailedToFitException;
import com.jstarcraft.ai.jsat.regression.RegressionDataSet;
import com.jstarcraft.ai.jsat.regression.RegressionModelEvaluation;
import com.jstarcraft.ai.jsat.regression.Regressor;
import com.jstarcraft.ai.jsat.utils.concurrent.ParallelUtils;
import com.jstarcraft.ai.jsat.utils.random.RandomUtil;

/**
 * Random Search is a simple method for tuning the parameters of a
 * classification or regression algorithm. Each parameter is given a
 * distribution that represents the values of interest, and trials are done by
 * randomly sampling each parameter from their respective distributions.
 * Compared to {@link GridSearch} this method does better when lots of values
 * are to be tested or when 2 or more parameters are to be evaluated. <br>
 * The model it takes must implement the {@link Parameterized} interface. By
 * default, no parameters are selected for optimization. This is because
 * parameter value ranges are often algorithm specific. As such, the user must
 * specify the parameters and the distributions to sample from using the
 * {@code addParameter} methods.
 * 
 * See: Bergstra, J., &amp; Bengio, Y. (2012). <i>Random Search for
 * Hyper-Parameter Optimization</i>. Journal of Machine Learning Research, 13,
 * 281–305.
 * 
 * @author Edward Raff
 */
public class RandomSearch extends ModelSearch {
    /**
     * The number of parameter combinations that are sampled, and therefore the
     * number of candidate models evaluated, by each call to {@code train}.
     */
    private int trials = 25;

    /**
     * The distributions to sample from. The entry at index {@code i} is the
     * distribution used for the parameter stored at index {@code i} of
     * {@code searchParams}; both lists must always have the same length.
     */
    private List<Distribution> searchValues;

    /**
     * Creates a new RandomSearch to tune the specified parameters of a regression
     * model. The parameters still need to be specified by calling
     * {@link #addParameter(DoubleParameter, Distribution)} or one of its
     * overloads.
     *
     * @param baseRegressor the regressor to tune the parameters of
     * @param folds         the number of folds of cross-validation to perform to
     *                      evaluate each combination of parameters
     * @throws FailedToFitException if the base regressor does not implement
     *                              {@link Parameterized}
     */
    public RandomSearch(Regressor baseRegressor, int folds) {
        super(baseRegressor, folds);
        searchValues = new ArrayList<>();
    }

    /**
     * Creates a new RandomSearch to tune the specified parameters of a
     * classification model. The parameters still need to be specified by calling
     * {@link #addParameter(DoubleParameter, Distribution)} or one of its
     * overloads.
     *
     * @param baseClassifier the classifier to tune the parameters of
     * @param folds          the number of folds of cross-validation to perform to
     *                       evaluate each combination of parameters
     * @throws FailedToFitException if the base classifier does not implement
     *                              {@link Parameterized}
     */
    public RandomSearch(Classifier baseClassifier, int folds) {
        super(baseClassifier, folds);
        searchValues = new ArrayList<>();
    }

    /**
     * Copy constructor. The number of trials is copied and every search
     * distribution is cloned, so the copy does not share distribution objects
     * with the original.
     *
     * @param toCopy the object to copy
     */
    public RandomSearch(RandomSearch toCopy) {
        super(toCopy);
        this.trials = toCopy.trials;
        this.searchValues = new ArrayList<>(toCopy.searchValues.size());
        for (Distribution d : toCopy.searchValues) {
            this.searchValues.add(d.clone());
        }
    }

    /**
     * This method will automatically populate the search space with parameters
     * based on which Parameter objects return non-null distributions.<br>
     * <br>
     * Note, using this method with Cross Validation has the potential for
     * over-estimating the accuracy of results if the data set is actually used
     * for parameter guessing.<br>
     * <br>
     * It is possible for this method to return 0, indicating that no default
     * parameters could be found. The intended interpretation is that there are no
     * parameters that you <i>need</i> to tune to get good performance from the
     * given model. Though there will be cases where the author has simply missed a
     * class.<br>
     * <br>
     * Only {@link DoubleParameter} and {@link IntParameter} instances are
     * considered; parameters of any other type are skipped.
     *
     * @param data the data set to get parameter estimates from
     * @return the number of parameters added
     */
    public int autoAddParameters(DataSet data) {
        // Exactly one of the two base models is expected to be set; prefer the classifier
        Parameterized obj;
        if (baseClassifier != null) {
            obj = (Parameterized) baseClassifier;
        } else {
            obj = (Parameterized) baseRegressor;
        }

        int totalParms = 0;
        for (Parameter param : obj.getParameters()) {
            if (param instanceof DoubleParameter) {
                Distribution dist = ((DoubleParameter) param).getGuess(data);
                if (dist != null) {
                    addParameter((DoubleParameter) param, dist);
                    totalParms++;
                }
            } else if (param instanceof IntParameter) {
                Distribution dist = ((IntParameter) param).getGuess(data);
                if (dist != null) {
                    addParameter((IntParameter) param, dist);
                    totalParms++;
                }
            }
        }

        return totalParms;
    }

    /**
     * Sets the number of trials or samples that will be taken. This value is the
     * number of models that will be trained and evaluated for their performance.
     *
     * @param trials the number of models to build and evaluate
     * @throws IllegalArgumentException if {@code trials} is less than 1
     */
    public void setTrials(int trials) {
        if (trials < 1) {
            throw new IllegalArgumentException("number of trials must be positive, not " + trials);
        }
        this.trials = trials;
    }

    /**
     * Returns the number of trials that will be performed.
     *
     * @return the number of models that will be built to evaluate
     */
    public int getTrials() {
        return trials;
    }

    /**
     * Adds a new double parameter to be altered for the model being tuned. The
     * given distribution is cloned, so later changes to {@code dist} do not affect
     * the search.
     *
     * @param param the model parameter
     * @param dist  the distribution to sample from for this parameter
     * @throws IllegalArgumentException if {@code param} or {@code dist} is null
     */
    public void addParameter(DoubleParameter param, Distribution dist) {
        if (param == null) {
            throw new IllegalArgumentException("null not allowed for parameter");
        }
        // Validate before touching either list so the two stay the same length on failure
        if (dist == null) {
            throw new IllegalArgumentException("null not allowed for distribution");
        }
        searchParams.add(param);
        searchValues.add(dist.clone());
    }

    /**
     * Adds a new integer parameter to be altered for the model being tuned. The
     * given distribution is cloned, so later changes to {@code dist} do not affect
     * the search. Values sampled from the distribution are rounded to the nearest
     * integer before being applied to the parameter.
     *
     * @param param the model parameter
     * @param dist  the distribution to sample from for this parameter
     * @throws IllegalArgumentException if {@code param} or {@code dist} is null
     */
    public void addParameter(IntParameter param, Distribution dist) {
        if (param == null) {
            throw new IllegalArgumentException("null not allowed for parameter");
        }
        // Validate before touching either list so the two stay the same length on failure
        if (dist == null) {
            throw new IllegalArgumentException("null not allowed for distribution");
        }
        searchParams.add(param);
        searchValues.add(dist.clone());
    }

    /**
     * Adds a new parameter, looked up by name, to be altered for the model being
     * tuned.
     *
     * @param name the name of the parameter
     * @param dist the distribution to sample from for the specified parameter
     * @throws IllegalArgumentException if the named parameter is neither a
     *                                  {@link DoubleParameter} nor an
     *                                  {@link IntParameter}, or if {@code dist} is
     *                                  null
     */
    public void addParameter(String name, Distribution dist) {
        Parameter param = getParameterByName(name);

        if (param instanceof DoubleParameter) {
            addParameter((DoubleParameter) param, dist);
        } else if (param instanceof IntParameter) {
            addParameter((IntParameter) param, dist);
        } else {
            throw new IllegalArgumentException("Parameter " + name + " is not for double or int values");
        }
    }

    /**
     * Draws one random value for every registered search parameter and applies it
     * to that parameter. Values for {@link IntParameter}s are rounded to the
     * nearest integer.
     *
     * @param rand the source of randomness used to sample each distribution
     */
    private void sampleParameterValues(Random rand) {
        for (int i = 0; i < searchParams.size(); i++) {
            // Inverse-transform sampling: map a uniform draw through the inverse CDF
            double sampledValue = searchValues.get(i).invCdf(rand.nextDouble());

            Parameter param = searchParams.get(i);
            if (param instanceof DoubleParameter) {
                ((DoubleParameter) param).setValue(sampledValue);
            } else if (param instanceof IntParameter) {
                ((IntParameter) param).setValue((int) Math.round(sampledValue));
            }
        }
    }

    /**
     * Orders two mean scores so that the better one sorts first. A NaN score is
     * always treated as the worst possible score, regardless of direction, so
     * that a model whose evaluation produced NaN can never be preferred over a
     * model with a real score.
     *
     * @param first         the mean score of the first model
     * @param second        the mean score of the second model
     * @param lowerIsBetter {@code true} if smaller scores are better
     * @return a negative value if {@code first} is better, a positive value if
     *         {@code second} is better, and zero if they are equivalent
     */
    private static int compareMeanScores(double first, double second, boolean lowerIsBetter) {
        boolean firstNaN = Double.isNaN(first);
        boolean secondNaN = Double.isNaN(second);
        if (firstNaN || secondNaN) {
            // false < true, so the NaN score sorts after the real one
            return Boolean.compare(firstNaN, secondNaN);
        }
        return lowerIsBetter ? Double.compare(first, second) : Double.compare(second, first);
    }

    /**
     * Samples {@link #getTrials()} parameter combinations, evaluates a copy of the
     * base classifier for each one with cross-validation, and keeps the copy with
     * the best mean target score as the trained classifier. If
     * {@code trainFinalModel} is set, the selected classifier is then trained on
     * the whole data set.
     *
     * @param dataSet  the data set to tune and train on
     * @param parallel whether parallelism may be used. When
     *                 {@code trainModelsInParallel} is set the candidate models are
     *                 evaluated in parallel; otherwise the flag is handed to each
     *                 individual evaluation instead
     */
    @Override
    public void train(final ClassificationDataSet dataSet, final boolean parallel) {
        // The head of this queue is always the best evaluation seen so far
        final PriorityQueue<ClassificationModelEvaluation> bestModels = new PriorityQueue<>(trials,
                (ClassificationModelEvaluation t, ClassificationModelEvaluation t1) -> compareMeanScores(
                        t.getScoreStats(classificationTargetScore).getMean(),
                        t1.getScoreStats(classificationTargetScore).getMean(),
                        classificationTargetScore.lowerIsBetter()));

        // Each model is set to have a different combination of parameters. We then
        // evaluate each model to determine the best one.
        final List<Classifier> paramsToEval = new ArrayList<>(trials);

        final Random rand = RandomUtil.getRandom();
        for (int trial = 0; trial < trials; trial++) {
            sampleParameterValues(rand);
            // Snapshot the base model while it holds this trial's sampled values
            paramsToEval.add(baseClassifier.clone());
        }

        // if we are doing our CV splits ahead of time, get them done now
        final List<ClassificationDataSet> preFolded;

        // Pre-combine our training combinations so that any caching can be re-used
        final List<ClassificationDataSet> trainCombinations;

        if (reuseSameCVFolds) {
            preFolded = dataSet.cvSet(folds);
            trainCombinations = new ArrayList<>(preFolded.size());
            for (int i = 0; i < preFolded.size(); i++) {
                trainCombinations.add(ClassificationDataSet.comineAllBut(preFolded, i));
            }
        } else {
            preFolded = null;
            trainCombinations = null;
        }

        ParallelUtils.run(parallel && trainModelsInParallel, paramsToEval.size(), (indx) -> {
            Classifier c = paramsToEval.get(indx);
            // Parallelize inside each evaluation only when the models themselves run serially
            ClassificationModelEvaluation cme = new ClassificationModelEvaluation(c, dataSet, !trainModelsInParallel && parallel);
            cme.addScorer(classificationTargetScore.clone());

            if (reuseSameCVFolds) {
                cme.evaluateCrossValidation(preFolded, trainCombinations);
            } else {
                cme.evaluateCrossValidation(folds);
            }

            // PriorityQueue is not thread safe, and evaluations may finish concurrently
            synchronized (bestModels) {
                bestModels.add(cme);
            }
        });

        Classifier bestClassifier = bestModels.peek().getClassifier();
        // Just re-train it on the whole set
        if (trainFinalModel) {
            bestClassifier.train(dataSet, parallel);
        }
        trainedClassifier = bestClassifier;
    }

    /**
     * Samples {@link #getTrials()} parameter combinations, evaluates a copy of the
     * base regressor for each one with cross-validation, and keeps the copy with
     * the best mean target score as the trained regressor. If
     * {@code trainFinalModel} is set, the selected regressor is then trained on the
     * whole data set.
     *
     * @param dataSet  the data set to tune and train on
     * @param parallel whether parallelism may be used. When
     *                 {@code trainModelsInParallel} is set the candidate models are
     *                 evaluated in parallel; otherwise the flag is handed to each
     *                 individual evaluation instead
     */
    @Override
    public void train(final RegressionDataSet dataSet, final boolean parallel) {
        // The head of this queue is always the best evaluation seen so far
        final PriorityQueue<RegressionModelEvaluation> bestModels = new PriorityQueue<>(trials,
                (RegressionModelEvaluation t, RegressionModelEvaluation t1) -> compareMeanScores(
                        t.getScoreStats(regressionTargetScore).getMean(),
                        t1.getScoreStats(regressionTargetScore).getMean(),
                        regressionTargetScore.lowerIsBetter()));

        // Each model is set to have a different combination of parameters. We then
        // evaluate each model to determine the best one.
        final List<Regressor> paramsToEval = new ArrayList<>(trials);

        final Random rand = RandomUtil.getRandom();
        for (int trial = 0; trial < trials; trial++) {
            sampleParameterValues(rand);
            // Snapshot the base model while it holds this trial's sampled values
            paramsToEval.add(baseRegressor.clone());
        }

        // if we are doing our CV splits ahead of time, get them done now
        final List<RegressionDataSet> preFolded;

        // Pre-combine our training combinations so that any caching can be re-used
        final List<RegressionDataSet> trainCombinations;

        if (reuseSameCVFolds) {
            preFolded = dataSet.cvSet(folds);
            trainCombinations = new ArrayList<>(preFolded.size());
            for (int i = 0; i < preFolded.size(); i++) {
                trainCombinations.add(RegressionDataSet.comineAllBut(preFolded, i));
            }
        } else {
            preFolded = null;
            trainCombinations = null;
        }

        ParallelUtils.run(parallel && trainModelsInParallel, paramsToEval.size(), (indx) -> {
            Regressor r = paramsToEval.get(indx);
            // Parallelize inside each evaluation only when the models themselves run serially
            RegressionModelEvaluation rme = new RegressionModelEvaluation(r, dataSet, !trainModelsInParallel && parallel);
            rme.addScorer(regressionTargetScore.clone());

            if (reuseSameCVFolds) {
                rme.evaluateCrossValidation(preFolded, trainCombinations);
            } else {
                rme.evaluateCrossValidation(folds);
            }

            // PriorityQueue is not thread safe, and evaluations may finish concurrently
            synchronized (bestModels) {
                bestModels.add(rme);
            }
        });

        Regressor bestRegressor = bestModels.peek().getRegressor();
        // Just re-train it on the whole set
        if (trainFinalModel) {
            bestRegressor.train(dataSet, parallel);
        }
        trainedRegressor = bestRegressor;
    }

    /**
     * Creates a copy of this search via the copy constructor.
     *
     * @return a new RandomSearch with the same settings as this one
     */
    @Override
    public RandomSearch clone() {
        return new RandomSearch(this);
    }

}
